# -*- coding:utf-8 -*-

import numpy as np
import config.glob.config as global_cfg
from base.main import predict
from tensorflow.examples.tutorials.mnist import input_data


"""
One-line prediction demo (一行代码预测): classify a few MNIST test images with a single ``predict`` call.
"""

def mnist_data_dir(dataset_root):
    """Return the MNIST data directory located under *dataset_root*.

    Built with ``os.path.join`` so the path is valid on every OS; the
    previous hard-coded ``'\\mnist\\'`` suffix only worked on Windows.

    Args:
        dataset_root: Root directory holding all datasets.

    Returns:
        The path ``<dataset_root>/mnist`` using the platform separator.
    """
    import os  # local import keeps the file's top-level import order untouched

    return os.path.join(dataset_root, 'mnist')


def main(num_samples=2, config_path='config/core/predict/lenet5.yml'):
    """Draw MNIST test images, run ``predict`` on them, and print the labels.

    Args:
        num_samples: How many examples to draw from the MNIST test split.
        config_path: Prediction config file passed straight to ``predict``.
    """
    # Draw `num_samples` examples from the MNIST *test* split (one-hot labels).
    mnist_test = input_data.read_data_sets(
        mnist_data_dir(global_cfg.DATASET_ROOT), one_hot=True).test
    x, y = mnist_test.next_batch(num_samples)
    # Reshape the images to (N, 28, 28, 1) before prediction.
    # NOTE(review): assumes the model in `config_path` expects NHWC input — confirm.
    x = np.reshape(x, [-1, 28, 28, 1])
    # Single-call prediction.
    y_pred = predict(config_path, x)
    # Print true vs. predicted class indices in magenta, then reset the color
    # so the terminal is not left colored after the script exits.
    # NOTE(review): assumes y_pred is (N, num_classes) scores — confirm.
    print('\033[35my:{}, y_pred:{}\033[0m'.format(
        np.argmax(y, axis=1), np.argmax(y_pred, axis=1)))


if __name__ == '__main__':
    main()
